{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inverse Compose, with Jax\n",
    "\n",
    "This runs the inverse compose benchmark with Jax, either on CPU or GPU.  We use the [jaxlie](https://brentyi.github.io/jaxlie) library for Lie Group operations.  We then compute the resulting point and jacobian of the point with respect to the pose, batched over large numbers of poses and points.  \n",
    "\n",
    "See [the paper](https://symforce.org/paper) for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as onp\n",
    "import jaxlie\n",
    "from jax import numpy as np\n",
    "import jax\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set to CPU\n",
    "# Comment out to use GPU/TPU\n",
    "jax.config.update(\"jax_platform_name\", \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the platform (CPU/GPU) we're using\n",
    "jax.lib.xla_bridge.get_backend().platform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def time_func(f, key, calls):\n",
    "    start = time.perf_counter()\n",
    "    for _ in range(calls):\n",
    "        f(key)\n",
    "        _, key = jax.random.split(key)\n",
    "    end = time.perf_counter()\n",
    "    return (end - start) / calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helpful source and documentation references:\n",
    "# https://github.com/brentyi/jaxlie/blob/9f177f2640641c38782ec1dc07709a41ea7713ea/jaxlie/manifold/_manifold_helpers.py\n",
    "# https://brentyi.github.io/jaxlie/vmap_usage/\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "@jax.vmap\n",
    "def inverse_compose(world_T_body, point):\n",
    "    # A helper function that computes the result as a function of the parameters\n",
    "    out_point = lambda parameters: jaxlie.SE3(parameters).inverse().apply(point)\n",
    "    # The jacobian of the output with respect to the parameters (not the tangent space)\n",
    "    # jacfwd is indeed better than jacrev here\n",
    "    result_D_storage = jax.jacfwd(out_point)(world_T_body.parameters())\n",
    "    # The jacobian of the parameters with respect to the tangent space\n",
    "    storage_D_tangent = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(world_T_body)\n",
    "    # Put it all together\n",
    "    J = result_D_storage @ storage_D_tangent\n",
    "    return out_point(world_T_body.parameters()), J"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key = jax.random.PRNGKey(42)\n",
    "\n",
    "for N in reversed([1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7]):\n",
    "    N = int(N)\n",
    "\n",
    "    storage = np.zeros((N, 7))\n",
    "    storage = storage.at[:, 0].set(1)\n",
    "    poses = jaxlie.SE3(wxyz_xyz=storage)\n",
    "\n",
    "    points = jax.random.normal(key, (N, 3))\n",
    "    _, key = jax.random.split(key)\n",
    "\n",
    "    inverse_compose(poses, points)\n",
    "\n",
    "    def random_inverse_compose(key):\n",
    "        points_new = points.at[0, 0].set(jax.random.normal(key))\n",
    "        return inverse_compose(poses, points_new)\n",
    "\n",
    "    t = time_func(random_inverse_compose, key, 10)\n",
    "\n",
    "    def random_ninverse_compose(key):\n",
    "        points_new = points.at[0, 0].set(jax.random.normal(key))\n",
    "        return points_new\n",
    "\n",
    "    _, key = jax.random.split(key)\n",
    "    t2 = time_func(random_ninverse_compose, key, 10)\n",
    "\n",
    "    print(f\"{N:>10}   {t:10.5} {t2:10.5} {t - t2:10.5} {(t - t2) / N:10.5}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "symforce-paper",
   "language": "python",
   "name": "symforce-paper"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
